# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate data for mcore flash_attn UT of inference."""
import numpy as np

# Not referenced in this file; presumably consumed by the test that imports this module — TODO confirm.
NUM_BLOCKS = 128
# Stride between blocks when get_init_params builds slot indices (slot = block * BLOCK_SIZE + offset).
BLOCK_SIZE = 16
# Leading (batch) dimension of every array generated by get_init_params.
BATCH_SIZE = 2
# Only SEQ_LENGTH[0] is read in this file: it is the number of prefill positions per sample.
SEQ_LENGTH = [2, 1]
# Per-head dimension is derived as HIDDEN_SIZE / NUM_HEADS.
HIDDEN_SIZE = 64
NUM_HEADS = 2
# Sequence-axis length of the randomly sampled query/key/value arrays.
max_seq_len = 3
# Seed the global NumPy RNG at import time so get_init_params draws a reproducible sequence.
np.random.seed(2025)


def get_init_params(n_kv_heads):
    """Generate query/key/value inputs and slot mappings for the prefill and decode steps.

    Query, key and value are drawn (in that order) from the global NumPy RNG and then
    split along the sequence axis: the first ``SEQ_LENGTH[0]`` positions of every sample
    become the prefill inputs and the position right after them becomes the single-token
    decode input.

    Args:
        n_kv_heads: Number of key/value heads; used as the head axis of the key and
            value arrays.

    Returns:
        dict: ``prefill_query``, ``prefill_key``, ``prefill_value``, ``decoder_query``,
        ``decoder_key`` and ``decoder_value`` arrays, plus ``prefill_slot_mapping`` and
        ``decoder_slot_mapping`` lists of integer slot indices.
    """
    head_dim = int(HIDDEN_SIZE / NUM_HEADS)

    # Keep the draw order query -> key -> value: all three consume the same global RNG
    # stream, so reordering them would change the generated values.
    query = np.random.normal(0, 0.01, (BATCH_SIZE, max_seq_len, NUM_HEADS, head_dim)).astype(np.float32)
    key = np.random.normal(0, 0.01, (BATCH_SIZE, max_seq_len, n_kv_heads, head_dim)).astype(np.float32)
    value = np.random.normal(0, 0.01, (BATCH_SIZE, max_seq_len, n_kv_heads, head_dim)).astype(np.float32)

    # np.zeros defaults to float64, so the returned arrays are float64 even though the
    # sampled data is float32. This is kept as-is because callers may rely on it.
    prefill_query = np.zeros((BATCH_SIZE, SEQ_LENGTH[0], NUM_HEADS, head_dim))
    prefill_key = np.zeros((BATCH_SIZE, SEQ_LENGTH[0], n_kv_heads, head_dim))
    prefill_value = np.zeros((BATCH_SIZE, SEQ_LENGTH[0], n_kv_heads, head_dim))
    decoder_query = np.zeros((BATCH_SIZE, 1, NUM_HEADS, head_dim))
    decoder_key = np.zeros((BATCH_SIZE, 1, n_kv_heads, head_dim))
    decoder_value = np.zeros((BATCH_SIZE, 1, n_kv_heads, head_dim))

    # Every sample holds SEQ_LENGTH[0] prefill tokens followed by one decode token.
    q_seq_lens = np.array([SEQ_LENGTH[0] + 1] * BATCH_SIZE, dtype=np.int32)
    kv_seq_lens = q_seq_lens

    for i, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens, kv_seq_lens)):
        q_last = q_seq_len - 1
        kv_last = kv_seq_len - 1
        # Everything before the last position is prefill input ...
        prefill_query[i, :q_last, :, :] = query[i, :q_last, :, :]
        prefill_key[i, :kv_last, :, :] = key[i, :kv_last, :, :]
        prefill_value[i, :kv_last, :, :] = value[i, :kv_last, :, :]
        # ... and the last position is the decode input.
        decoder_query[i, :, :, :] = query[i, q_last:q_seq_len, :, :]
        decoder_key[i, :, :, :] = key[i, kv_last:kv_seq_len, :, :]
        decoder_value[i, :, :, :] = value[i, kv_last:kv_seq_len, :, :]

    # One run of consecutive slots per sample. The outer loop must cover BATCH_SIZE
    # samples because the list is consumed once per entry of q_seq_lens below.
    # NOTE(review): the block index is hard-coded as (2 - i), which goes negative for
    # i > 2 — confirm the intended block layout before raising BATCH_SIZE above 3.
    slot_mapping_list = [
        (2 - i) * BLOCK_SIZE + j
        for i in range(BATCH_SIZE)
        for j in range(SEQ_LENGTH[0] + 1)
    ]

    prefill_slot_mapping = []
    decoder_slot_mapping = []
    base_idx = 0
    for seq_len in q_seq_lens:
        last = base_idx + int(seq_len) - 1
        prefill_slot_mapping.extend(slot_mapping_list[base_idx:last])
        decoder_slot_mapping.append(slot_mapping_list[last])
        base_idx = last + 1

    return {
        "prefill_query": prefill_query,
        "decoder_query": decoder_query,
        "prefill_key": prefill_key,
        "decoder_key": decoder_key,
        "prefill_value": prefill_value,
        "decoder_value": decoder_value,
        "prefill_slot_mapping": prefill_slot_mapping,
        "decoder_slot_mapping": decoder_slot_mapping
    }


def get_golden() -> dict[str, np.ndarray]:
    """Return the hard-coded golden (reference) outputs for the test.

    All four arrays are float32 literals; nothing is computed here.

    Returns:
        dict[str, np.ndarray]: Keys ``prefill_output_1``, ``decode_output_1``,
        ``prefill_output_2`` and ``decode_output_2``. The prefill arrays have 4 rows
        and the decode arrays 2 rows, each row holding 64 values.
    """
    # NOTE(review): the "_1" / "_2" suffixes presumably distinguish two test
    # configurations (e.g. different n_kv_heads) — confirm against the consuming test.
    # 4 rows x 64 values; presumably (BATCH_SIZE * SEQ_LENGTH[0], HIDDEN_SIZE) — TODO confirm.
    prefill_output_1 = np.array(
        [[4.78210440e-03, -1.52899837e-03, -5.56694623e-03,
          -9.00148973e-03, 1.05525330e-02, 7.28358328e-03,
          1.19536219e-03, -8.57044291e-03, 2.25798823e-02,
          1.52548915e-02, 1.14306314e-02, 4.22954559e-03,
          -2.65551009e-03, 7.20735872e-03, 2.88539450e-03,
          3.07936152e-03, -4.30210913e-03, -6.89400476e-05,
          -6.66053081e-03, 8.29378981e-03, -2.16212263e-03,
          2.21097488e-02, -5.35418140e-03, 1.68196810e-03,
          1.76066030e-02, -8.63478519e-03, 1.17711229e-02,
          -9.88265220e-03, 9.06197412e-04, -7.94475316e-04,
          -1.70679495e-03, 1.50222974e-02, 1.33759836e-02,
          7.18731852e-03, -8.20096117e-04, -1.66804511e-02,
          6.97347708e-03, -2.72352074e-04, -7.56404968e-03,
          2.37823208e-03, 2.21757265e-03, -1.64047768e-03,
          4.37840028e-03, 2.34601367e-02, -3.30534228e-03,
          1.11571671e-02, 2.31472589e-02, 4.72508650e-03,
          2.69948738e-03, 1.08594699e-02, 1.01232354e-03,
          4.17403504e-03, -1.32588865e-02, -6.92426809e-04,
          1.43721299e-02, 9.02891741e-04, 1.99090932e-02,
          9.56536445e-04, 8.33047926e-03, -7.71264359e-03,
          -1.71506195e-03, -3.52917006e-03, 1.03621709e-03,
          7.01961527e-03],
         [-2.60598026e-04, -6.53053727e-03, -2.47998862e-03,
          -9.50972317e-04, 1.30181350e-02, -2.29367753e-03,
          -8.61999393e-03, -6.84198365e-03, 1.80038344e-02,
          1.82520086e-03, 3.02014942e-03, 1.01187089e-02,
          -2.74236547e-03, 3.45974462e-03, 2.46873125e-03,
          5.44394460e-03, -2.88731046e-03, -9.21299681e-03,
          -3.90724745e-03, 5.94609324e-03, -1.52785946e-02,
          1.73709057e-02, -3.71670816e-03, -7.45223742e-03,
          6.49126340e-03, 4.40022349e-03, 9.83623974e-03,
          -4.13824199e-03, -7.70682527e-05, -1.13775302e-02,
          -1.08878221e-03, 1.25014214e-02, 2.39800988e-03,
          6.45176414e-03, -4.29697288e-03, -5.32770669e-03,
          -3.14920023e-03, 1.72498345e-03, -1.05702300e-02,
          -6.64277282e-03, -1.26713607e-03, -2.34883185e-03,
          2.74387212e-03, 6.88648783e-03, -4.03576251e-03,
          5.25689404e-03, 1.72633342e-02, 7.93139357e-03,
          1.89374713e-03, 1.28059750e-02, 1.23624195e-04,
          1.31471567e-02, 4.11688350e-04, -2.53668823e-03,
          4.04767040e-03, -7.75695778e-04, 1.54021569e-02,
          -3.50244809e-03, 7.13004451e-03, -2.34724605e-03,
          5.77936647e-03, -6.56210678e-03, -1.19833369e-03,
          -5.00915572e-04],
         [-8.66170134e-03, 8.89367249e-04, 1.22052734e-03,
          1.90485306e-02, 6.58875844e-03, -6.60980222e-05,
          6.56141015e-03, -1.38718328e-02, 3.81375384e-03,
          6.14774507e-03, -2.07991078e-02, -9.13245976e-03,
          -1.51165458e-03, -9.59596969e-03, -8.24277988e-04,
          -3.14953359e-04, -1.41623067e-02, -1.02188997e-05,
          5.00594405e-03, 1.49255637e-02, -1.47915343e-02,
          -1.01372618e-02, 5.25486190e-03, 1.95807219e-02,
          -5.13543049e-03, 1.43842271e-03, -4.97125275e-03,
          -2.13892888e-02, -9.90206469e-03, 7.96035398e-03,
          -1.97778898e-03, -4.22544731e-03, -1.58708356e-03,
          -4.82766191e-03, -1.60874762e-02, -1.27465371e-02,
          -3.06730741e-03, -3.94886918e-03, -1.61607768e-02,
          6.30674371e-03, 8.33656918e-03, 9.13738832e-03,
          1.09790522e-03, 1.07610943e-02, -3.94306425e-03,
          -1.28040351e-02, -9.44316853e-03, 2.24394873e-02,
          -4.79225488e-03, -1.11161508e-02, -1.96754161e-04,
          1.06743593e-02, 4.77723544e-03, 3.67596396e-04,
          8.62549897e-03, 6.14453049e-04, -1.15034794e-02,
          -8.24867934e-03, 7.25020654e-03, -1.35460952e-02,
          -4.77738259e-03, 6.77016983e-03, -5.92909381e-03,
          1.16246836e-02],
         [-7.60730496e-03, -3.38602788e-03, 7.32476404e-03,
          2.61355145e-03, -9.41273291e-03, -4.22585243e-03,
          3.00572719e-03, -6.72120461e-03, 6.42809784e-03,
          2.91007408e-03, -1.59442723e-02, -6.58231415e-03,
          -5.68541512e-03, -8.25291313e-03, 3.35798669e-03,
          -3.98758682e-04, -6.62942184e-03, 8.28395039e-03,
          3.56924534e-03, 1.66691653e-03, -1.10254809e-02,
          -4.78450442e-03, -4.91220597e-03, 1.53125795e-02,
          -6.43940596e-03, -5.27261710e-03, -1.81339367e-03,
          -2.13420112e-03, -2.71633337e-03, 9.79778171e-03,
          -2.38726032e-04, -4.13238443e-03, -1.36292586e-03,
          -8.60901084e-04, -4.73896973e-03, -1.11285932e-02,
          -4.51437663e-03, 3.21099418e-03, 3.82826664e-04,
          1.19104348e-02, 1.67473834e-02, 1.14163179e-02,
          -9.93292406e-03, 1.12861432e-02, -6.56289328e-03,
          -3.41419852e-03, -9.85774398e-03, 1.28035331e-02,
          3.16470629e-04, -2.07677973e-03, -6.87306095e-03,
          8.58728960e-03, 4.65855515e-03, -4.04392648e-03,
          1.40767731e-02, 2.19462207e-03, -8.67531914e-03,
          -5.63254021e-03, 7.34996516e-03, -1.26888342e-02,
          -9.32146329e-03, 5.43479482e-03, -8.32411461e-03,
          1.93351302e-02]], dtype=np.float32)

    # 2 rows x 64 values; presumably one row per sample for the decode step — TODO confirm.
    decode_output_1 = np.array(
        [[3.61592928e-03, 2.61437707e-03, 6.08491362e-04,
          2.33785412e-03, 6.44481089e-03, 2.70417961e-03,
          -8.65238998e-03, -4.49167984e-03, 9.80007090e-03,
          2.04570708e-04, 3.86979105e-03, 1.13519430e-02,
          -5.57967182e-03, -2.86455732e-04, 2.42063543e-03,
          3.41859204e-03, -8.68803123e-04, -8.50711577e-03,
          -4.46806103e-03, 5.64815011e-04, -8.68324935e-03,
          9.73200798e-03, -2.93338206e-04, -2.61687580e-03,
          5.07365260e-03, 7.01112300e-03, 9.04510729e-03,
          1.71381002e-03, -2.60546361e-03, -4.27930243e-03,
          3.65901215e-04, 7.97689985e-03, 7.55381025e-03,
          1.87521474e-03, -2.77796504e-03, -2.91224336e-03,
          -2.83729937e-03, 3.86234652e-03, -1.68574275e-03,
          -7.93046318e-04, -4.62861685e-03, -4.20821737e-03,
          5.02487272e-03, 1.50676351e-03, -3.00203287e-03,
          -1.78001355e-05, 4.73647332e-03, -7.11526722e-04,
          -1.04203774e-03, 1.07188355e-02, 5.80536434e-04,
          9.69703682e-03, 1.67516945e-03, -1.93794572e-03,
          6.47268724e-03, -4.49661748e-05, 1.07991155e-02,
          3.13543063e-03, 3.69883608e-03, -7.14767608e-04,
          4.70202556e-03, 1.40629709e-04, 3.05922632e-03,
          -3.99756012e-03],
         [-2.13302812e-03, 9.17183701e-04, 8.52490123e-03,
          2.19240179e-03, -9.28870030e-03, -5.05185127e-03,
          4.85166581e-03, -6.47869566e-03, 2.28362158e-04,
          4.77675907e-03, -5.32109337e-03, -8.36760178e-03,
          -1.70376082e-03, -4.85174963e-03, 4.77672601e-03,
          4.24314232e-04, -2.63072643e-03, 6.54791435e-03,
          6.34494238e-03, 2.69650365e-03, -8.56713578e-03,
          -9.97176394e-03, -2.01292289e-03, 8.40762071e-03,
          -7.38070719e-03, -5.71429171e-03, 5.25554293e-04,
          4.17887000e-03, 3.34309624e-03, 1.23008788e-02,
          1.24979066e-03, -4.38148994e-03, -1.97624834e-03,
          -1.37303467e-03, -1.15559157e-03, -7.95578770e-03,
          -2.85348436e-03, 2.90999538e-03, 2.81144661e-04,
          1.91812404e-03, 1.29502267e-02, 1.03283189e-02,
          -5.03751636e-03, 7.73250591e-03, -1.03768352e-02,
          5.79087529e-03, -1.01380218e-02, 7.01679988e-03,
          -1.02382002e-03, 2.72996514e-03, -4.59825806e-03,
          5.37903886e-03, 2.16376409e-03, -3.60030471e-03,
          1.08234677e-02, 2.01319112e-03, -5.67783741e-03,
          1.21135591e-03, 4.46130941e-03, -2.16152612e-03,
          -9.13385302e-03, 7.63837248e-04, 2.71514803e-03,
          1.29194092e-02]], dtype=np.float32)

    # Second reference set: 4 rows x 64 values.
    prefill_output_2 = np.array(
        [[-2.0780e-02, 2.0475e-02, -1.3746e-02, 1.9507e-02, 3.6718e-04,
          8.1405e-04, 3.5358e-03, -1.5215e-03, 6.0567e-03, -3.1497e-03,
          -1.3619e-02, 3.8499e-03, 2.7327e-03, 4.5982e-04, 1.1575e-02,
          4.5783e-03, 5.2530e-03, 1.8155e-03, -8.7103e-03, -5.7384e-03,
          1.3325e-03, -6.5637e-03, 5.0603e-03, 6.2378e-03, 6.4293e-03,
          1.3622e-02, -8.1465e-03, -1.9629e-02, 6.3881e-04, 1.0041e-02,
          -6.5523e-03, -1.9953e-03, -2.0780e-02, 2.0475e-02, -1.3746e-02,
          1.9507e-02, 3.6718e-04, 8.1405e-04, 3.5358e-03, -1.5215e-03,
          6.0567e-03, -3.1497e-03, -1.3619e-02, 3.8499e-03, 2.7327e-03,
          4.5982e-04, 1.1575e-02, 4.5783e-03, 5.2530e-03, 1.8155e-03,
          -8.7103e-03, -5.7384e-03, 1.3325e-03, -6.5637e-03, 5.0603e-03,
          6.2378e-03, 6.4293e-03, 1.3622e-02, -8.1465e-03, -1.9629e-02,
          6.3881e-04, 1.0041e-02, -6.5523e-03, -1.9953e-03],
         [-1.0039e-02, 9.2820e-03, -4.3426e-03, 1.0261e-02, -6.5935e-03,
          9.2921e-03, 5.2382e-03, -3.7600e-04, 3.7519e-03, -6.4452e-03,
          -1.2528e-02, -2.3479e-03, 5.0990e-03, 8.9682e-03, 1.0474e-02,
          1.1032e-03, 5.9061e-03, 8.0319e-03, -7.7266e-03, 2.9235e-03,
          4.5516e-03, -7.2681e-03, 5.5776e-03, 1.4107e-02, 2.0945e-03,
          1.1987e-02, -5.2940e-03, -5.0924e-03, -3.1530e-03, 3.1974e-03,
          -1.4590e-03, 5.4575e-03, -1.0041e-02, 9.2836e-03, -4.3440e-03,
          1.0263e-02, -6.5925e-03, 9.2908e-03, 5.2380e-03, -3.7617e-04,
          3.7522e-03, -6.4448e-03, -1.2529e-02, -2.3470e-03, 5.0987e-03,
          8.9670e-03, 1.0474e-02, 1.1037e-03, 5.9060e-03, 8.0310e-03,
          -7.7267e-03, 2.9222e-03, 4.5512e-03, -7.2680e-03, 5.5776e-03,
          1.4106e-02, 2.0951e-03, 1.1987e-02, -5.2945e-03, -5.0944e-03,
          -3.1525e-03, 3.1984e-03, -1.4597e-03, 5.4564e-03],
         [1.0058e-03, -6.4866e-03, 3.3673e-03, -1.0064e-02, 1.0509e-02,
          1.1198e-02, -3.1837e-02, -1.9736e-03, -1.4347e-02, -9.6224e-04,
          -1.5505e-02, -2.4017e-02, -9.5637e-05, 3.2249e-03, 7.5859e-04,
          1.8615e-02, 7.0843e-03, -4.1924e-03, -1.0095e-02, -5.5779e-03,
          -5.0885e-03, 4.3815e-03, -4.2524e-03, 2.4045e-03, 4.0141e-03,
          1.2775e-03, -4.4678e-03, 4.0257e-03, 1.0731e-02, -2.0945e-02,
          1.7967e-02, -8.8001e-03, 1.0058e-03, -6.4866e-03, 3.3673e-03,
          -1.0064e-02, 1.0509e-02, 1.1198e-02, -3.1837e-02, -1.9736e-03,
          -1.4347e-02, -9.6224e-04, -1.5505e-02, -2.4017e-02, -9.5637e-05,
          3.2249e-03, 7.5859e-04, 1.8615e-02, 7.0843e-03, -4.1924e-03,
          -1.0095e-02, -5.5779e-03, -5.0885e-03, 4.3815e-03, -4.2524e-03,
          2.4045e-03, 4.0141e-03, 1.2775e-03, -4.4678e-03, 4.0257e-03,
          1.0731e-02, -2.0945e-02, 1.7967e-02, -8.8001e-03],
         [-6.7304e-03, 6.0323e-04, 2.8916e-03, -4.3857e-03, 8.1079e-03,
          8.7009e-03, -1.3418e-02, -7.3087e-04, -2.2669e-02, 2.7679e-03,
          -5.0918e-03, -1.5307e-02, -3.0946e-03, 3.9407e-03, -2.3452e-03,
          2.3273e-03, -3.4857e-03, -7.0373e-04, -3.9893e-03, 3.6511e-03,
          -7.3825e-03, 6.7118e-03, -2.4720e-03, -4.2788e-03, 3.2801e-03,
          3.2141e-03, 5.7742e-03, -3.5357e-03, -2.9317e-03, -6.0316e-03,
          6.1659e-03, -3.7697e-03, -6.7309e-03, 6.0374e-04, 2.8915e-03,
          -4.3853e-03, 8.1078e-03, 8.7008e-03, -1.3416e-02, -7.3078e-04,
          -2.2669e-02, 2.7682e-03, -5.0911e-03, -1.5307e-02, -3.0948e-03,
          3.9407e-03, -2.3455e-03, 2.3261e-03, -3.4864e-03, -7.0348e-04,
          -3.9889e-03, 3.6518e-03, -7.3827e-03, 6.7120e-03, -2.4719e-03,
          -4.2793e-03, 3.2801e-03, 3.2143e-03, 5.7750e-03, -3.5362e-03,
          -2.9327e-03, -6.0305e-03, 6.1650e-03, -3.7694e-03]], dtype=np.float32)

    # Second reference set: 2 rows x 64 values.
    decode_output_2 = np.array(
        [[-0.00210043, 0.00504539, 0.00043085, 0.00961853,
          0.00117281, 0.00502496, 0.00156756, 0.00424701,
          -0.00108411, -0.00596694, -0.01077254, -0.00023065,
          0.00102961, 0.00345692, 0.00880646, 0.00626904,
          0.00620982, 0.00755412, -0.00534197, 0.00316502,
          0.00213747, -0.00625003, -0.00035512, 0.00852213,
          0.00119731, 0.01115908, -0.00293302, -0.00395734,
          -0.00691232, -0.00368381, -0.00099835, 0.00240651,
          -0.00210073, 0.00504592, 0.00043046, 0.00961913,
          0.00117371, 0.00502414, 0.00156724, 0.00424718,
          -0.00108422, -0.00596669, -0.01077252, -0.0002301,
          0.00102922, 0.00345604, 0.00880645, 0.00626956,
          0.00620979, 0.00755367, -0.00534191, 0.00316444,
          0.00213711, -0.00624993, -0.00035548, 0.00852128,
          0.00119756, 0.01115914, -0.00293308, -0.00395827,
          -0.00691227, -0.00368372, -0.00099867, 0.00240583],
         [-0.0056486, 0.00073335, -0.00240162, 0.00679418,
          0.00419669, 0.00481898, -0.01241597, -0.00032082,
          -0.01371034, 0.00276547, -0.00421181, -0.01319709,
          -0.00074479, 0.00158098, -0.00613273, -0.00155136,
          -0.0040056, -0.00231432, -0.01203496, 0.00388645,
          -0.00457965, 0.00257275, -0.00091292, -0.00273539,
          0.00100923, 0.00107641, 0.00223413, -0.00134823,
          -0.00256657, -0.00167776, 0.00349989, -0.0041323,
          -0.00564966, 0.0007342, -0.00240114, 0.00679374,
          0.00419679, 0.00481907, -0.01241381, -0.00032071,
          -0.01371228, 0.00276592, -0.00421062, -0.01319623,
          -0.0007454, 0.00158131, -0.00613272, -0.00155296,
          -0.00400685, -0.00231373, -0.01203338, 0.00388756,
          -0.00458022, 0.00257346, -0.00091286, -0.00273637,
          0.00100937, 0.00107686, 0.00223575, -0.00134938,
          -0.00256828, -0.00167637, 0.00349871, -0.00413164]], dtype=np.float32)

    return {
        "prefill_output_1": prefill_output_1,
        "decode_output_1": decode_output_1,
        "prefill_output_2": prefill_output_2,
        "decode_output_2": decode_output_2,
    }


def get_gpu_data() -> dict[str, np.ndarray]:
    """Return the hard-coded GPU reference outputs for the test.

    Uses the same keys and array shapes as ``get_golden``, but every array is stored
    as float16. Nothing is computed here; the values are literals.

    Returns:
        dict[str, np.ndarray]: Keys ``prefill_output_1``, ``decode_output_1``,
        ``prefill_output_2`` and ``decode_output_2``. The prefill arrays have 4 rows
        and the decode arrays 2 rows, each row holding 64 values.
    """
    # NOTE(review): presumably captured from a GPU run of the same test cases —
    # the provenance is not visible in this file, confirm before regenerating.
    # 4 rows x 64 values.
    prefill_output_1 = np.array(
        [[4.7913e-03, -1.5259e-03, -5.5542e-03, -8.9722e-03,
          1.0559e-02, 7.2937e-03, 1.1978e-03, -8.5449e-03,
          2.2583e-02, 1.5259e-02, 1.1414e-02, 4.2419e-03,
          -2.6550e-03, 7.2021e-03, 2.8839e-03, 3.0823e-03,
          -4.3030e-03, -6.9141e-05, -6.6528e-03, 8.3008e-03,
          -2.1667e-03, 2.2095e-02, -5.3406e-03, 1.6785e-03,
          1.7578e-02, -8.6060e-03, 1.1780e-02, -9.8877e-03,
          9.0790e-04, -7.9346e-04, -1.7090e-03, 1.5015e-02,
          1.3367e-02, 7.2021e-03, -8.2016e-04, -1.6724e-02,
          6.9885e-03, -2.7275e-04, -7.5684e-03, 2.3804e-03,
          2.2125e-03, -1.6403e-03, 4.3640e-03, 2.3438e-02,
          -3.3112e-03, 1.1169e-02, 2.3193e-02, 4.7302e-03,
          2.7008e-03, 1.0864e-02, 1.0147e-03, 4.1809e-03,
          -1.3245e-02, -6.9427e-04, 1.4343e-02, 9.0408e-04,
          1.9897e-02, 9.5749e-04, 8.3008e-03, -7.7209e-03,
          -1.7166e-03, -3.5248e-03, 1.0376e-03, 7.0190e-03],
         [-2.5940e-04, -6.5308e-03, -2.4719e-03, -9.3079e-04,
          1.3062e-02, -2.2736e-03, -8.6060e-03, -6.8359e-03,
          1.8066e-02, 1.8311e-03, 3.0060e-03, 1.0132e-02,
          -2.7466e-03, 3.4637e-03, 2.4719e-03, 5.4626e-03,
          -2.8839e-03, -9.2163e-03, -3.9062e-03, 5.9509e-03,
          -1.5320e-02, 1.7334e-02, -3.7079e-03, -7.4768e-03,
          6.5002e-03, 4.4250e-03, 9.8267e-03, -4.1504e-03,
          -7.6294e-05, -1.1414e-02, -1.0910e-03, 1.2512e-02,
          2.3804e-03, 6.4697e-03, -4.3030e-03, -5.3711e-03,
          -3.1281e-03, 1.7242e-03, -1.0559e-02, -6.6223e-03,
          -1.2741e-03, -2.3499e-03, 2.7313e-03, 6.8665e-03,
          -4.0283e-03, 5.2795e-03, 1.7334e-02, 7.9346e-03,
          1.8997e-03, 1.2817e-02, 1.2398e-04, 1.3123e-02,
          4.2725e-04, -2.5482e-03, 4.0283e-03, -7.7820e-04,
          1.5442e-02, -3.4943e-03, 7.1106e-03, -2.3499e-03,
          5.7678e-03, -6.5613e-03, -1.1978e-03, -4.8828e-04],
         [-8.6670e-03, 8.8882e-04, 1.2207e-03, 1.9043e-02,
          6.5918e-03, -6.6280e-05, 6.5613e-03, -1.3855e-02,
          3.8147e-03, 6.1340e-03, -2.0752e-02, -9.1553e-03,
          -1.5106e-03, -9.5825e-03, -8.2397e-04, -3.1471e-04,
          -1.4160e-02, -1.0192e-05, 5.0049e-03, 1.4954e-02,
          -1.4771e-02, -1.0132e-02, 5.2490e-03, 1.9531e-02,
          -5.1270e-03, 1.4420e-03, -4.9744e-03, -2.1362e-02,
          -9.8877e-03, 7.9346e-03, -1.9836e-03, -4.2114e-03,
          -1.5869e-03, -4.8218e-03, -1.6113e-02, -1.2756e-02,
          -3.0670e-03, -3.9368e-03, -1.6113e-02, 6.3171e-03,
          8.3618e-03, 9.1553e-03, 1.0986e-03, 1.0742e-02,
          -3.9368e-03, -1.2817e-02, -9.4604e-03, 2.2461e-02,
          -4.7913e-03, -1.1108e-02, -1.9646e-04, 1.0681e-02,
          4.7913e-03, 3.6812e-04, 8.6060e-03, 6.1417e-04,
          -1.1475e-02, -8.2397e-03, 7.2632e-03, -1.3550e-02,
          -4.7913e-03, 6.7749e-03, -5.9204e-03, 1.1597e-02],
         [-7.6294e-03, -3.3875e-03, 7.3242e-03, 2.6245e-03,
          -9.3994e-03, -4.2114e-03, 3.0060e-03, -6.7139e-03,
          6.4392e-03, 2.8992e-03, -1.5991e-02, -6.5918e-03,
          -5.7068e-03, -8.2397e-03, 3.3569e-03, -3.9864e-04,
          -6.6223e-03, 8.3008e-03, 3.5706e-03, 1.6785e-03,
          -1.1047e-02, -4.7913e-03, -4.9133e-03, 1.5320e-02,
          -6.4392e-03, -5.2490e-03, -1.8158e-03, -2.1362e-03,
          -2.7161e-03, 9.8267e-03, -2.4033e-04, -4.1199e-03,
          -1.3657e-03, -8.6212e-04, -4.7607e-03, -1.1169e-02,
          -4.5166e-03, 3.2196e-03, 4.2725e-04, 1.1902e-02,
          1.6724e-02, 1.1414e-02, -9.9487e-03, 1.1292e-02,
          -6.5613e-03, -3.4180e-03, -9.8877e-03, 1.2817e-02,
          3.2043e-04, -2.0752e-03, -6.8665e-03, 8.6060e-03,
          4.6692e-03, -4.0588e-03, 1.4099e-02, 2.1973e-03,
          -8.6670e-03, -5.6458e-03, 7.3547e-03, -1.2695e-02,
          -9.3384e-03, 5.4321e-03, -8.3618e-03, 1.9409e-02]], dtype=np.float16)

    # 2 rows x 64 values.
    decode_output_1 = np.array(
        [[3.6163e-03, 2.5940e-03, 6.0654e-04, 2.3651e-03,
          6.4697e-03, 2.7161e-03, -8.6670e-03, -4.4861e-03,
          9.8267e-03, 2.0599e-04, 3.8757e-03, 1.1353e-02,
          -5.5847e-03, -2.7847e-04, 2.4109e-03, 3.4332e-03,
          -8.6594e-04, -8.4839e-03, -4.4861e-03, 5.7983e-04,
          -8.7280e-03, 9.7656e-03, -2.8229e-04, -2.6398e-03,
          5.0659e-03, 7.0496e-03, 9.0332e-03, 1.7242e-03,
          -2.6093e-03, -4.3030e-03, 3.7003e-04, 7.9956e-03,
          7.5378e-03, 1.8768e-03, -2.7771e-03, -2.9297e-03,
          -2.8381e-03, 3.8757e-03, -1.6785e-03, -7.7820e-04,
          -4.6387e-03, -4.2114e-03, 5.0354e-03, 1.4801e-03,
          -3.0060e-03, -1.5259e-05, 4.7607e-03, -6.9427e-04,
          -1.0529e-03, 1.0742e-02, 5.8365e-04, 9.7046e-03,
          1.6785e-03, -1.9455e-03, 6.4697e-03, -4.5776e-05,
          1.0803e-02, 3.1281e-03, 3.7079e-03, -7.2098e-04,
          4.6997e-03, 1.3733e-04, 3.0823e-03, -3.9673e-03],
         [-2.1667e-03, 9.0790e-04, 8.5449e-03, 2.1973e-03,
          -9.2773e-03, -5.0659e-03, 4.8828e-03, -6.5002e-03,
          2.3651e-04, 4.7607e-03, -5.3101e-03, -8.4229e-03,
          -1.7090e-03, -4.8523e-03, 4.7913e-03, 4.2725e-04,
          -2.6398e-03, 6.5918e-03, 6.3477e-03, 2.7161e-03,
          -8.6060e-03, -1.0010e-02, -2.0142e-03, 8.4229e-03,
          -7.3853e-03, -5.7068e-03, 5.2643e-04, 4.1809e-03,
          3.3569e-03, 1.2329e-02, 1.2589e-03, -4.3945e-03,
          -1.9836e-03, -1.3733e-03, -1.1597e-03, -7.9956e-03,
          -2.8534e-03, 2.9297e-03, 3.3379e-04, 1.8921e-03,
          1.3000e-02, 1.0376e-02, -5.0659e-03, 7.7515e-03,
          -1.0376e-02, 5.7983e-03, -1.0132e-02, 7.0496e-03,
          -1.0223e-03, 2.7313e-03, -4.6082e-03, 5.4016e-03,
          2.1820e-03, -3.6163e-03, 1.0864e-02, 2.0142e-03,
          -5.6763e-03, 1.2207e-03, 4.4861e-03, -2.1667e-03,
          -9.1553e-03, 7.5531e-04, 2.7313e-03, 1.2939e-02]], dtype=np.float16)

    # Second reference set: 4 rows x 64 values.
    prefill_output_2 = np.array(
        [[-2.0782e-02, 2.0477e-02, -1.3748e-02, 1.9501e-02, 3.6716e-04,
          8.1396e-04, 3.5362e-03, -1.5211e-03, 6.0577e-03, -3.1490e-03,
          -1.3618e-02, 3.8490e-03, 2.7332e-03, 4.5991e-04, 1.1574e-02,
          4.5776e-03, 5.2528e-03, 1.8158e-03, -8.7128e-03, -5.7373e-03,
          1.3323e-03, -6.5651e-03, 5.0621e-03, 6.2370e-03, 6.4278e-03,
          1.3618e-02, -8.1482e-03, -1.9623e-02, 6.3896e-04, 1.0040e-02,
          -6.5536e-03, -1.9951e-03, -2.0782e-02, 2.0477e-02, -1.3748e-02,
          1.9501e-02, 3.6716e-04, 8.1396e-04, 3.5362e-03, -1.5211e-03,
          6.0577e-03, -3.1490e-03, -1.3618e-02, 3.8490e-03, 2.7332e-03,
          4.5991e-04, 1.1574e-02, 4.5776e-03, 5.2528e-03, 1.8158e-03,
          -8.7128e-03, -5.7373e-03, 1.3323e-03, -6.5651e-03, 5.0621e-03,
          6.2370e-03, 6.4278e-03, 1.3618e-02, -8.1482e-03, -1.9623e-02,
          6.3896e-04, 1.0040e-02, -6.5536e-03, -1.9951e-03],
         [-1.0040e-02, 9.2850e-03, -4.3411e-03, 1.0262e-02, -6.5918e-03,
          9.2926e-03, 5.2376e-03, -3.7599e-04, 3.7518e-03, -6.4468e-03,
          -1.2527e-02, -2.3479e-03, 5.1003e-03, 8.9645e-03, 1.0475e-02,
          1.1034e-03, 5.9052e-03, 8.0338e-03, -7.7248e-03, 2.9240e-03,
          4.5509e-03, -7.2670e-03, 5.5771e-03, 1.4107e-02, 2.0943e-03,
          1.1986e-02, -5.2948e-03, -5.0926e-03, -3.1528e-03, 3.1967e-03,
          -1.4591e-03, 5.4588e-03, -1.0040e-02, 9.2850e-03, -4.3449e-03,
          1.0262e-02, -6.5918e-03, 9.2926e-03, 5.2376e-03, -3.7622e-04,
          3.7518e-03, -6.4430e-03, -1.2527e-02, -2.3479e-03, 5.1003e-03,
          8.9645e-03, 1.0475e-02, 1.1034e-03, 5.9052e-03, 8.0338e-03,
          -7.7286e-03, 2.9221e-03, 4.5509e-03, -7.2670e-03, 5.5771e-03,
          1.4107e-02, 2.0943e-03, 1.1986e-02, -5.2948e-03, -5.0926e-03,
          -3.1528e-03, 3.1986e-03, -1.4601e-03, 5.4550e-03],
         [1.0061e-03, -6.4850e-03, 3.3665e-03, -1.0063e-02, 1.0506e-02,
          1.1200e-02, -3.1830e-02, -1.9741e-03, -1.4343e-02, -9.6226e-04,
          -1.5503e-02, -2.4017e-02, -9.5665e-05, 3.2253e-03, 7.5865e-04,
          1.8616e-02, 7.0839e-03, -4.1924e-03, -1.0094e-02, -5.5771e-03,
          -5.0888e-03, 4.3831e-03, -4.2534e-03, 2.4052e-03, 4.0131e-03,
          1.2779e-03, -4.4670e-03, 4.0245e-03, 1.0735e-02, -2.0950e-02,
          1.7960e-02, -8.7967e-03, 1.0061e-03, -6.4850e-03, 3.3665e-03,
          -1.0063e-02, 1.0506e-02, 1.1200e-02, -3.1830e-02, -1.9741e-03,
          -1.4343e-02, -9.6226e-04, -1.5503e-02, -2.4017e-02, -9.5665e-05,
          3.2253e-03, 7.5865e-04, 1.8616e-02, 7.0839e-03, -4.1924e-03,
          -1.0094e-02, -5.5771e-03, -5.0888e-03, 4.3831e-03, -4.2534e-03,
          2.4052e-03, 4.0131e-03, 1.2779e-03, -4.4670e-03, 4.0245e-03,
          1.0735e-02, -2.0950e-02, 1.7960e-02, -8.7967e-03],
         [-6.7291e-03, 6.0320e-04, 2.8915e-03, -4.3869e-03, 8.1100e-03,
          8.6975e-03, -1.3420e-02, -7.3099e-04, -2.2675e-02, 2.7676e-03,
          -5.0926e-03, -1.5305e-02, -3.0937e-03, 3.9406e-03, -2.3460e-03,
          2.3270e-03, -3.4866e-03, -7.0381e-04, -3.9902e-03, 3.6507e-03,
          -7.3814e-03, 6.7101e-03, -2.4719e-03, -4.2801e-03, 3.2806e-03,
          3.2139e-03, 5.7755e-03, -3.5362e-03, -2.9316e-03, -6.0310e-03,
          6.1646e-03, -3.7689e-03, -6.7291e-03, 6.0368e-04, 2.8915e-03,
          -4.3869e-03, 8.1100e-03, 8.6975e-03, -1.3412e-02, -7.3099e-04,
          -2.2675e-02, 2.7676e-03, -5.0926e-03, -1.5305e-02, -3.0956e-03,
          3.9406e-03, -2.3460e-03, 2.3270e-03, -3.4866e-03, -7.0333e-04,
          -3.9902e-03, 3.6526e-03, -7.3814e-03, 6.7139e-03, -2.4719e-03,
          -4.2801e-03, 3.2806e-03, 3.2139e-03, 5.7755e-03, -3.5362e-03,
          -2.9335e-03, -6.0310e-03, 6.1646e-03, -3.7689e-03]], dtype=np.float16)

    # Second reference set: 2 rows x 64 values.
    decode_output_2 = np.array(
        [[-0.00209, 0.005066, 0.0004425, 0.00964, 0.00119,
          0.005035, 0.001579, 0.004272, -0.001083, -0.00598,
          -0.0108, -0.0002365, 0.001038, 0.003448, 0.00885,
          0.006287, 0.006226, 0.00757, -0.00537, 0.003174,
          0.002136, -0.006287, -0.000351, 0.008545, 0.001205,
          0.01117, -0.00293, -0.003967, -0.006927, -0.003693,
          -0.001007, 0.002396, -0.00209, 0.005066, 0.0004425,
          0.00964, 0.00119, 0.005035, 0.001579, 0.004272,
          -0.001083, -0.00598, -0.0108, -0.0002365, 0.001038,
          0.003448, 0.00885, 0.006287, 0.006226, 0.00757,
          -0.00537, 0.003174, 0.002136, -0.006287, -0.000351,
          0.008545, 0.001205, 0.01117, -0.00293, -0.003967,
          -0.006927, -0.003693, -0.001007, 0.002396],
         [-0.005646, 0.0007286, -0.002396, 0.006836, 0.00421,
          0.00482, -0.01245, -0.0003185, -0.013794, 0.002777,
          -0.004242, -0.013245, -0.0007477, 0.001579, -0.006134,
          -0.001587, -0.00403, -0.00232, -0.012024, 0.003906,
          -0.00461, 0.002579, -0.000908, -0.002747, 0.001022,
          0.001083, 0.002228, -0.001358, -0.002579, -0.001694,
          0.003479, -0.00412, -0.005646, 0.0007286, -0.002396,
          0.006836, 0.00421, 0.00482, -0.01245, -0.0003185,
          -0.013794, 0.002777, -0.004242, -0.013245, -0.0007477,
          0.001579, -0.006134, -0.001587, -0.00403, -0.00232,
          -0.012024, 0.003906, -0.00461, 0.002579, -0.000908,
          -0.002747, 0.001022, 0.001083, 0.002228, -0.001358,
          -0.002579, -0.001694, 0.003479, -0.00412]], dtype=np.float16)

    return {
        "prefill_output_1": prefill_output_1,
        "decode_output_1": decode_output_1,
        "prefill_output_2": prefill_output_2,
        "decode_output_2": decode_output_2,
    }


# Build both reference tables once at import time so importers can use them directly.
GOLDEN_DATA = get_golden()  # float32 reference outputs
GPU_DATA = get_gpu_data()  # float16 reference outputs
